{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "voF4T3w6mBDn"
   },
   "source": [
    "# 9.12 Kaggle Competition in Practice: Image Classification (CIFAR-10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2Pw3DAlujcBn"
   },
   "source": [
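    "We experimented with the CIFAR-10 dataset in the \"Image Augmentation\" section. It is an important dataset in computer vision. We now apply what we have learned in the previous sections to the Kaggle competition on CIFAR-10 image classification. The competition webpage is <https://www.kaggle.com/c/cifar-10>.\n",
    "\n",
    "Figure 9.16 shows the competition webpage. To be able to submit results, first register an account on the Kaggle website."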
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dg_CR1O7FskM"
   },
   "source": [
    "![Figure 9.16: The CIFAR-10 image classification competition webpage](https://zh.d2l.ai/_images/kaggle_cifar10.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zs4NtA8xCn_3"
   },
   "outputs": [],
   "source": [
    "# Select TensorFlow 2.x\n",
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  pass\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OIi6p-OyFUpd"
   },
   "source": [
    "## 9.12.1. Obtaining and Organizing the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "psMnKA85FYxC"
   },
   "source": [
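    "The competition data is divided into a training set and a test set. The training set contains 50,000 images. The test set contains 300,000 images, of which 10,000 are used for scoring; the other 290,000 unscored images are included to prevent participants from labeling the test set by hand and submitting those labels. Images in both sets are PNG files, 32 pixels in height and width, with three RGB (color) channels. The images cover 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. The upper-left corner of Figure 9.16 shows some of the airplane, automobile, and bird images in the dataset."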
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "n5xLlNaMFxtc"
   },
   "source": [
    "### 9.12.1.1. Downloading the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AHrP25nWF1z3"
   },
   "source": [
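    "After logging in to Kaggle, click the \"Data\" tab on the CIFAR-10 competition webpage shown in Figure 9.16 and download the training set train.7z, the test set test.7z, and the training labels trainLabels.csv."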
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hO6fDr5aUq4F"
   },
   "source": [
    "### 9.12.1.2. Decompressing the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wSxYPViXU0u2"
   },
   "source": [
    "After downloading train.7z and test.7z, decompress them. Then place the training set, the test set, and the training labels at the following three paths:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o5jXsplUU1dv"
   },
   "source": [
    "* ../../data/kaggle_cifar10/train/[1-50000].png;\n",
    "* ../../data/kaggle_cifar10/test/[1-300000].png;\n",
    "* ../../data/kaggle_cifar10/trainLabels.csv."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "X60x8J7yU7F_"
   },
   "source": [
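    "To help you get started quickly, we provide small samples of the above datasets: train_tiny.zip contains 100 training examples, and test_tiny.zip contains only 1 test example. After decompression, their folders are named train_tiny and test_tiny. In addition, decompress the zipped training labels to obtain trainLabels.csv."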
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "UCgqAEPZRsOn",
    "outputId": "56f3b3d6-57ba-41fa-919e-67232d1fa3dd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ok\n",
      "ok\n",
      "ok\n"
     ]
    }
   ],
   "source": [
    "import zipfile\n",
    "def extract_zip(file_name, target_path):\n",
    "    \"\"\"unzip zip file\"\"\"\n",
    "    zip_file = zipfile.ZipFile(file_name)\n",
    "    for names in zip_file.namelist():\n",
    "        zip_file.extract(names,target_path)\n",
    "    zip_file.close()\n",
    "    print(\"ok\")\n",
    "test_tiny_name = '../../data/kaggle_cifar10/test_tiny.zip'\n",
    "train_tiny_name = '../../data/kaggle_cifar10/train_tiny.zip'\n",
    "trainLabelsCsv_name = '../../data/kaggle_cifar10/trainLabels.csv.zip'\n",
    "extract_zip(test_tiny_name, \"../../data/kaggle_cifar10\")\n",
    "extract_zip(train_tiny_name, \"../../data/kaggle_cifar10\")\n",
    "extract_zip(trainLabelsCsv_name, \"../../data/kaggle_cifar10\")"
   ]
  },
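  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An equivalent, slightly more compact way to unpack an archive is sketched below as an optional alternative; it is not part of the original pipeline. It relies only on Python's standard zipfile module: ZipFile works as a context manager, which closes the archive automatically, and extractall unpacks every member in one call. The helper name extract_zip_all is hypothetical, and the demo builds a throwaway archive in a temporary directory instead of touching the Kaggle files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "import zipfile\n",
    "\n",
    "\n",
    "def extract_zip_all(file_name, target_path):\n",
    "    # Hypothetical helper: extract every member of file_name into target_path.\n",
    "    # The with-statement closes the archive even if extraction fails.\n",
    "    with zipfile.ZipFile(file_name) as zip_file:\n",
    "        zip_file.extractall(target_path)\n",
    "        return zip_file.namelist()\n",
    "\n",
    "\n",
    "# Demo on a throwaway archive, not on the Kaggle data\n",
    "tmp_dir = tempfile.mkdtemp()\n",
    "archive = os.path.join(tmp_dir, 'demo.zip')\n",
    "with zipfile.ZipFile(archive, 'w') as zf:\n",
    "    zf.writestr('train_tiny/1.png', b'not really a png')\n",
    "members = extract_zip_all(archive, os.path.join(tmp_dir, 'out'))\n",
    "print(members)  # ['train_tiny/1.png']"
   ]
  },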
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nYyiFmQKSr5d"
   },
   "source": [
    "### 9.12.1.3. Organizing the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sKVJuQq-StJD"
   },
   "source": [
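    "We need to organize the dataset so that the model can be trained and tested conveniently. The following read_label_file function reads the label file of the training set. Its parameter valid_ratio is the ratio of the number of validation examples to the number of examples in the original training set. The function returns the number of training examples to keep per class, together with a dictionary idx_label that maps each image index to its class name."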
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8knFXl2lSj-4"
   },
   "outputs": [],
   "source": [
    "def read_label_file(data_dir, label_file, train_dir, valid_ratio):\n",
    "    with open(os.path.join(data_dir, label_file), 'r') as f:\n",
    "        # Skip the header line (column names)\n",
    "        lines = f.readlines()[1:]\n",
    "        tokens = [l.rstrip().split(',') for l in lines]\n",
    "        idx_label = dict(((int(idx), label) for idx, label in tokens))\n",
    "    labels = set(idx_label.values())\n",
    "    n_train_valid = len(os.listdir(os.path.join(data_dir, train_dir)))\n",
    "    n_train = int(n_train_valid * (1 - valid_ratio))\n",
    "    assert 0 < n_train < n_train_valid\n",
    "    return n_train // len(labels), idx_label"
   ]
  },
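  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the return value of read_label_file concrete, the cell below repeats its arithmetic on plain numbers. This is only an illustration: the helper train_per_label is hypothetical, and like read_label_file it assumes that the training quota is split evenly across the 10 classes. With valid_ratio=0.1, the full Kaggle training set keeps 45,000 images for training, and the tiny demo set keeps 90."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_per_label(n_train_valid, valid_ratio, num_classes):\n",
    "    # Same arithmetic as read_label_file: keep (1 - valid_ratio) of the images\n",
    "    # for training and split that quota evenly across the classes\n",
    "    n_train = int(n_train_valid * (1 - valid_ratio))\n",
    "    assert 0 < n_train < n_train_valid\n",
    "    return n_train // num_classes\n",
    "\n",
    "\n",
    "print(train_per_label(50000, 0.1, 10))  # 4500 per class on the full training set\n",
    "print(train_per_label(100, 0.1, 10))  # 9 per class on train_tiny"
   ]
  },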
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AxaIt75uVEvW"
   },
   "source": [
    "Next we define a helper function that creates a path only if it does not exist yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lmVzx44uU-xH"
   },
   "outputs": [],
   "source": [
    "# This function is saved in the d2lzh package for later use\n",
    "def mkdir_if_not_exist(path):\n",
    "    if not os.path.exists(os.path.join(*path)):\n",
    "        os.makedirs(os.path.join(*path))"
   ]
  },
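  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since Python 3.2, os.makedirs accepts exist_ok=True, which gives the same behavior as mkdir_if_not_exist in a single call and also avoids a race between the existence check and the creation. A minimal sketch follows; the name mkdir_if_not_exist_v2 is hypothetical and takes the same list of path components, and the demo works in a temporary directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def mkdir_if_not_exist_v2(path):\n",
    "    # path is a list of path components, as in mkdir_if_not_exist;\n",
    "    # exist_ok=True turns the call into a no-op if the directory already exists\n",
    "    os.makedirs(os.path.join(*path), exist_ok=True)\n",
    "\n",
    "\n",
    "root = tempfile.mkdtemp()\n",
    "mkdir_if_not_exist_v2([root, 'train_valid_test', 'train', 'cat'])\n",
    "mkdir_if_not_exist_v2([root, 'train_valid_test', 'train', 'cat'])  # no error the second time\n",
    "print(os.path.isdir(os.path.join(root, 'train_valid_test', 'train', 'cat')))  # True"
   ]
  },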
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lHJty28Gkyth"
   },
   "source": [
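    "Next we define the reorg_train_valid function to split a validation set off the original training set. Take valid_ratio=0.1 as an example: since the original training set has 50,000 images, 45,000 of them are used for training during hyperparameter tuning and stored under input_dir/train, while the other 5,000 images form the validation set and are stored under input_dir/valid. All images are also copied to input_dir/train_valid for the final training run. After reorganization, images of the same class are placed in the same folder, which makes them easy to read later."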
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RPdBD1aEVZRx"
   },
   "outputs": [],
   "source": [
    "def reorg_train_valid(data_dir, train_dir, input_dir,\n",
    "            n_train_per_label, idx_label):\n",
    "    label_count = {}\n",
    "    for train_file in os.listdir(os.path.join(data_dir, train_dir)):\n",
    "        idx = int(train_file.split('.')[0])\n",
    "        label = idx_label[idx]\n",
    "        mkdir_if_not_exist([data_dir, input_dir, 'train_valid', label])\n",
    "        shutil.copy(os.path.join(data_dir, train_dir, train_file),\n",
    "            os.path.join(data_dir, input_dir, \"train_valid\", label))\n",
    "        if label not in label_count or label_count[label] < n_train_per_label:\n",
    "            mkdir_if_not_exist([data_dir, input_dir, 'train', label])\n",
    "            shutil.copy(os.path.join(data_dir, train_dir, train_file),\n",
    "                os.path.join(data_dir, input_dir, 'train', label))\n",
    "            # get() returns 0 for a label not seen yet, so the count starts at 1\n",
    "            label_count[label] = label_count.get(label, 0) + 1\n",
    "        else:\n",
    "            mkdir_if_not_exist([data_dir, input_dir, 'valid', label])\n",
    "            shutil.copy(os.path.join(data_dir, train_dir, train_file),\n",
    "                os.path.join(data_dir, input_dir, 'valid', label))"
   ]
  },
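  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-class quota logic of reorg_train_valid can be checked without copying any files. The sketch below is only an illustration (split_by_quota and the toy labels are hypothetical): it applies the same branching to a dictionary of image indices and records whether each image would go to train or valid; every image would additionally be copied to train_valid. Unlike reorg_train_valid, which follows the arbitrary order returned by os.listdir, the sketch visits the indices in sorted order so the result is reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "\n",
    "def split_by_quota(idx_label, n_train_per_label):\n",
    "    # Mirrors the branching in reorg_train_valid, but only records the\n",
    "    # destination folder of each image instead of copying it\n",
    "    label_count = {}\n",
    "    split = {}\n",
    "    for idx in sorted(idx_label):\n",
    "        label = idx_label[idx]\n",
    "        if label_count.get(label, 0) < n_train_per_label:\n",
    "            split[idx] = 'train'\n",
    "            label_count[label] = label_count.get(label, 0) + 1\n",
    "        else:\n",
    "            split[idx] = 'valid'\n",
    "    return split\n",
    "\n",
    "\n",
    "# Toy example: 3 cats and 2 dogs with a quota of 2 training images per class\n",
    "toy_labels = {1: 'cat', 2: 'dog', 3: 'cat', 4: 'cat', 5: 'dog'}\n",
    "toy_split = split_by_quota(toy_labels, 2)\n",
    "print(toy_split)  # only image 4 (the third cat) goes to valid\n",
    "print(Counter(toy_split.values()))"
   ]
  },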
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o8CFMbDIEIx5"
   },
   "source": [
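    "The reorg_test function below organizes the test set so that it can be read conveniently at prediction time. Because the test images have no labels, they are all placed in a single placeholder class folder named unknown."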
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FNr7DtI4CDW9"
   },
   "outputs": [],
   "source": [
    "def reorg_test(data_dir, test_dir, input_dir):\n",
    "    mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])\n",
    "    for test_file in os.listdir(os.path.join(data_dir, test_dir)):\n",
    "        shutil.copy(os.path.join(data_dir, test_dir, test_file),\n",
    "            os.path.join(data_dir, input_dir, 'test', 'unknown'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Nw-kPJE5FNUw"
   },
   "source": [
    "Finally, we define a function that calls the read_label_file, reorg_train_valid, and reorg_test functions defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SuZFK5RqFJwJ"
   },
   "outputs": [],
   "source": [
    "def reorg_cifar10_data(data_dir, label_file, train_dir, test_dir,\n",
    "            input_dir, valid_ratio):\n",
    "    n_train_per_label, idx_label = read_label_file(data_dir, label_file,\n",
    "                            train_dir, valid_ratio)\n",
    "    reorg_train_valid(data_dir, train_dir, input_dir, n_train_per_label,\n",
    "            idx_label)\n",
    "    reorg_test(data_dir, test_dir, input_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NzIp7otGGYCH"
   },
   "source": [
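    "Here we use only 100 training examples and 1 test example, stored in the folders train_tiny and test_tiny. Accordingly, we set the batch size to just 1. For actual training and testing, use the full Kaggle dataset and set batch_size to a larger integer, such as 128. We hold out 10% of the training examples as the validation set for hyperparameter tuning."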
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BiWMtja9GVpy"
   },
   "outputs": [],
   "source": [
    "demo = True\n",
    "\n",
    "if demo:\n",
    "    # Note: this uses the tiny training and test sets with a correspondingly small batch size. With the full Kaggle dataset, use a larger batch size\n",
    "    train_dir, test_dir, batch_size = 'train_tiny', 'test_tiny', 1\n",
    "else:\n",
    "    train_dir, test_dir, batch_size = 'train', 'test', 128\n",
    "\n",
    "data_dir = '../../data/kaggle_cifar10'\n",
    "label_file = 'trainLabels.csv'\n",
    "input_dir = 'train_valid_test'\n",
    "valid_ratio = 0.1\n",
    "\n",
    "reorg_cifar10_data(data_dir, label_file, train_dir, test_dir,\n",
    "            input_dir, valid_ratio)"
   ]
  },
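  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After reorganization it is worth checking how many images ended up in each folder. The helper count_images below is a hypothetical sketch: it walks a split directory laid out as `<label>/<file>` and counts the files per class. The demo builds a tiny fake tree in a temporary directory; on the real data you could call, for example, count_images(os.path.join(data_dir, input_dir, 'train'))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def count_images(split_dir):\n",
    "    # Return {class name: number of files} for a directory laid out as <label>/<file>\n",
    "    counts = {}\n",
    "    for cls_name in sorted(os.listdir(split_dir)):\n",
    "        cls_dir = os.path.join(split_dir, cls_name)\n",
    "        if os.path.isdir(cls_dir):\n",
    "            counts[cls_name] = len(os.listdir(cls_dir))\n",
    "    return counts\n",
    "\n",
    "\n",
    "# Fake tree for the demo: train/cat holds 2 files, train/dog holds 1 file\n",
    "demo_root = tempfile.mkdtemp()\n",
    "for cls_name, file_names in [('cat', ['1.png', '3.png']), ('dog', ['2.png'])]:\n",
    "    os.makedirs(os.path.join(demo_root, 'train', cls_name))\n",
    "    for fname in file_names:\n",
    "        open(os.path.join(demo_root, 'train', cls_name, fname), 'wb').close()\n",
    "print(count_images(os.path.join(demo_root, 'train')))  # {'cat': 2, 'dog': 1}"
   ]
  },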
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TekJFvqQHXTC"
   },
   "source": [
    "## 9.12.2. Image Augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FIKkTojGHkr-"
   },
   "source": [
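    "To cope with overfitting, we use image augmentation. For example, tf.image.random_flip_left_right randomly mirrors an image horizontally. tf.image.per_image_standardization, by contrast, standardizes a whole image with the mean and standard deviation of all of its pixel values; to standardize the three RGB channels separately, subtract a per-channel mean and divide by a per-channel standard deviation, as the code below does. Some of these operations are listed below; you can decide whether to use or modify them as needed."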
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "S-RsjynDHQp1"
   },
   "outputs": [],
   "source": [
    "def transform_train(feature, label):\n",
    "    # Upscale the image to a 40x40 square\n",
    "    feature = tf.image.resize(feature, [40, 40])\n",
    "    # Randomly crop a square whose area is 0.64 to 1 times that of the 40x40 image,\n",
    "    # i.e. whose side is 0.8 * 40 = 32 to 40 pixels, then resize it back to 32x32\n",
    "    crop_size = tf.random.uniform(shape=(), minval=32, maxval=41, dtype=tf.int32)\n",
    "    feature = tf.image.random_crop(feature, size=[crop_size, crop_size, 3])\n",
    "    feature = tf.image.resize(feature, size=[32, 32])\n",
    "    feature = tf.image.random_flip_left_right(feature)\n",
    "    feature = tf.image.random_flip_up_down(feature)\n",
    "    # parse_image has already scaled the pixels to [0, 1], so do not divide by 255 again\n",
    "    # Alternative: feature = tf.image.per_image_standardization(feature)\n",
    "    # Standardize each RGB channel with its own mean and standard deviation\n",
    "    mean = tf.constant([0.4914, 0.4822, 0.4465])\n",
    "    std = tf.constant([0.2023, 0.1994, 0.2010])\n",
    "    feature = (feature - mean) / std\n",
    "    return feature, label"
   ]
  },
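  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The crop-size range in transform_train follows from the area ratio: a square covering 0.64 to 1 times the area of a 40x40 image has a side length of sqrt(0.64) x 40 = 32 to 40 pixels. The NumPy sketch below performs the same random square crop and resize on a single image so the shapes can be checked without TensorFlow. It is an illustration under stated assumptions: random_crop_resize is a hypothetical function, and nearest-neighbor resizing stands in for tf.image.resize, which interpolates bilinearly by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def random_crop_resize(image, min_area=0.64, out_size=32, rng=None):\n",
    "    # image: array of shape (H, W, C) with H == W, e.g. the 40x40 upscaled image\n",
    "    rng = np.random.default_rng() if rng is None else rng\n",
    "    size = image.shape[0]\n",
    "    # Side-length range implied by the area range [min_area, 1]\n",
    "    min_side = int(round(math.sqrt(min_area) * size))\n",
    "    side = int(rng.integers(min_side, size + 1))  # upper bound is exclusive, so + 1\n",
    "    top = int(rng.integers(0, size - side + 1))\n",
    "    left = int(rng.integers(0, size - side + 1))\n",
    "    crop = image[top:top + side, left:left + side]\n",
    "    # Nearest-neighbor resize to out_size x out_size\n",
    "    rows = np.arange(out_size) * side // out_size\n",
    "    cols = np.arange(out_size) * side // out_size\n",
    "    return crop[rows][:, cols]\n",
    "\n",
    "\n",
    "img40 = np.random.default_rng(0).random((40, 40, 3))\n",
    "cropped = random_crop_resize(img40, rng=np.random.default_rng(1))\n",
    "print(round(math.sqrt(0.64) * 40))  # 32, the smallest crop side\n",
    "print(cropped.shape)  # (32, 32, 3)"
   ]
  },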
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hQR5IgeBa0WA"
   },
   "source": [
    "At test time, to keep the output deterministic, we only standardize the images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Y1jyRP_jaTQM"
   },
   "outputs": [],
   "source": [
    "def transform_test(feature, label):\n",
    "    # parse_image has already scaled the pixels to [0, 1], so only standardize,\n",
    "    # using the same per-channel statistics as transform_train\n",
    "    # feature = tf.image.per_image_standardization(feature)\n",
    "    mean = tf.constant([0.4914, 0.4822, 0.4465])\n",
    "    std = tf.constant([0.2023, 0.1994, 0.2010])\n",
    "    feature = (feature - mean) / std\n",
    "    return feature, label"
   ]
  },
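  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both transform functions standardize each channel as (x - mean) / std with the per-channel CIFAR-10 statistics hard-coded above, applied to pixel values that are already in [0, 1]. The NumPy sketch below (the function normalize is hypothetical) shows the effect: a pixel equal to the channel means maps to zero, whereas dividing by 255 a second time would squash every input toward black before standardization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "MEAN = np.array([0.4914, 0.4822, 0.4465])\n",
    "STD = np.array([0.2023, 0.1994, 0.2010])\n",
    "\n",
    "\n",
    "def normalize(image):\n",
    "    # image: float array of shape (H, W, 3) with values in [0, 1];\n",
    "    # broadcasting applies each channel's mean and std to that channel\n",
    "    return (image - MEAN) / STD\n",
    "\n",
    "\n",
    "mean_pixel = MEAN.reshape(1, 1, 3)\n",
    "print(normalize(mean_pixel))  # zeros for every channel\n",
    "\n",
    "white = np.ones((1, 1, 3))\n",
    "print(normalize(white).round(2))  # positive values between about 2.5 and 2.8\n",
    "# Scaling by 255 twice makes even a white pixel look almost black\n",
    "print(normalize(white / 255).round(2))"
   ]
  },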
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f2wAPlDtbDcB"
   },
   "source": [
    "## 9.12.3. Reading the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GTugOIFKbEpY"
   },
   "source": [
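    "Next, we read the reorganized dataset of raw image files with tf.data: tf.data.Dataset.list_files lists the image paths, and a parsing function then turns each path into an example consisting of an image and its label."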
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TT7dsVpFbBSy"
   },
   "outputs": [],
   "source": [
    "# Get the file names. list_files shuffles the files by default; keep the test set in a fixed order for prediction\n",
    "train_list_ds = tf.data.Dataset.list_files(os.path.join(data_dir, input_dir, 'train', '*/*'))\n",
    "valid_list_ds = tf.data.Dataset.list_files(os.path.join(data_dir, input_dir, 'valid', '*/*'))\n",
    "train_valid_list_ds = tf.data.Dataset.list_files(os.path.join(data_dir, input_dir, 'train_valid', '*/*'))\n",
    "test_ds = tf.data.Dataset.list_files(os.path.join(data_dir, input_dir, 'test', '*/*'), shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "y0kfH7MYP-Pc"
   },
   "outputs": [],
   "source": [
    "def idx_and_label(data_dir, label_file, train_dir):\n",
    "    with open(os.path.join(data_dir, label_file), 'r') as f:\n",
    "        # Skip the header line (column names)\n",
    "        lines = f.readlines()[1:]\n",
    "        tokens = [l.rstrip().split(',') for l in lines]\n",
    "        idx_label = dict(((int(idx), label) for idx, label in tokens))\n",
    "    # Sort the class names so that every run assigns the same index to each class\n",
    "    labels = sorted(set(idx_label.values()))\n",
    "    idx = list(range(len(labels)))\n",
    "    return idx, labels"
   ]
  },
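  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Python does not guarantee a stable iteration order for a set of strings across interpreter sessions (string hashing is randomized), so enumerating the label set directly could assign different indices to the classes each time the notebook restarts; sorting the class names fixes the mapping. The pure-Python sketch below (build_label_maps is a hypothetical name) builds the two dictionaries that the StaticHashTable lookup tables in the following cell encode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "                   'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "\n",
    "\n",
    "def build_label_maps(class_names):\n",
    "    # Sort first so the class-to-index mapping does not depend on set ordering\n",
    "    names = sorted(set(class_names))\n",
    "    label_to_idx = {name: i for i, name in enumerate(names)}\n",
    "    idx_to_label = {i: name for name, i in label_to_idx.items()}\n",
    "    return label_to_idx, idx_to_label\n",
    "\n",
    "\n",
    "label_to_idx, idx_to_label = build_label_maps(CIFAR10_CLASSES)\n",
    "print(label_to_idx['frog'], idx_to_label[8])  # 6 ship"
   ]
  },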
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IRtFi0xEyTXZ"
   },
   "outputs": [],
   "source": [
    "idx, label = idx_and_label(data_dir, label_file, train_dir)\n",
    "# Build lookup tables between class names and integer indices\n",
    "label = tf.convert_to_tensor(label)\n",
    "idx = tf.convert_to_tensor(idx)\n",
    "label2idx = tf.lookup.StaticHashTable(\n",
    "            tf.lookup.KeyValueTensorInitializer(\n",
    "                label, idx, key_dtype=tf.string, value_dtype=tf.int32), -1)\n",
    "idx2label = tf.lookup.StaticHashTable(\n",
    "            tf.lookup.KeyValueTensorInitializer(\n",
    "                idx, label, key_dtype=tf.int32, value_dtype=tf.string), \"unknown\")\n",
    "\n",
    "# label2idx.lookup(tf.convert_to_tensor('cat'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iAFQUjWtiQls"
   },
   "outputs": [],
   "source": [
    "def parse_image(filename):\n",
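    "    # Derive the image and its label from the file path .../<label>/<index>.png\n",
    "    parts = tf.strings.split(filename, os.sep)\n",
    "    label = parts[-2]\n",
    "    label = label2idx.lookup(label)\n",
    "\n",
    "    image = tf.io.read_file(filename)\n",
    "    # The CIFAR-10 files are PNG images with three RGB channels\n",
    "    image = tf.io.decode_png(image, channels=3)\n",
    "    # Convert to float32, which also rescales the pixel values to [0, 1]\n",
    "    image = tf.image.convert_image_dtype(image, tf.float32)\n",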
    "    return image, label"
   ]
  },
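  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "parse_image relies on the folder layout created by reorg_train_valid and reorg_test: the class name is the name of the directory that contains the file. The standard-library sketch below (label_from_path is a hypothetical helper, and the example paths are made up) extracts it with os.path, which respects the platform's path separator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def label_from_path(path):\n",
    "    # .../<split>/<label>/<index>.png -> <label>\n",
    "    return os.path.basename(os.path.dirname(path))\n",
    "\n",
    "\n",
    "example = os.path.join('..', '..', 'data', 'kaggle_cifar10',\n",
    "                       'train_valid_test', 'train', 'frog', '19.png')\n",
    "print(label_from_path(example))  # frog\n",
    "print(label_from_path(os.path.join('train_valid_test', 'test', 'unknown', '1.png')))  # unknown"
   ]
  },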
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Cw8Fa2UngzhS"
   },
   "outputs": [],
   "source": [
    "train_ds = train_list_ds.map(parse_image)\n",
    "valid_ds = valid_list_ds.map(parse_image)\n",
    "train_valid_ds = train_valid_list_ds.map(parse_image)\n",
    "test_ds = test_ds.map(parse_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 300
    },
    "colab_type": "code",
    "id": "nnncIUwmipXT",
    "outputId": "ac11500d-346d-424d-9907-02ef7cf51420"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'frog', shape=(), dtype=string)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f5e9fc4d208>"
      ]
     },
     "execution_count": 18,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAfMklEQVR4nO2da2yc53Xn/2dunOGdFC+SKNmy5UvtNLbiqIbXyXaTBi3coKgTYJFNPgT+EFRF0QAN0P1gZIFNFtgPyWKTIB8WWSgbt+4im8vm0hiFsW1qpDDaFK7l2PG9tizLkSiKokRS5HCGcz37YcZb2fv8H9IiOVTy/H+AoOF7+LzvmWfe877zPn+ec8zdIYT41Sez2w4IIXqDgl2IRFCwC5EICnYhEkHBLkQiKNiFSITcVgab2X0AvgogC+B/uPsXYr+fz+e9r1gM2lqtFh2XQVgezBo/ViHHr2P5iC2XzVKbWfiAZpFrZsTHZpO/55ggmo35SKTUtrf5sdr8aJaJvIEI7Xb4vcV8j+4v4r9FJpnZMhE/shn+ebJzAADaERnbYycCGxPdX5jF5VWUK+vBg111sJtZFsB/A/DbAM4CeNLMHnH3F9mYvmIRR+56b9C2vLxIj9WXCX/Q4wU+Gdft6ae2yfEBapsYHaS2QjYf3J7rK9ExyPIpXlxaprZ6k7+3sdERasu0GsHttVqNjllfX6e2Yil8cQaAFvjFqlItB7ePjA7TMXC+v3qtTm1ZhD8XgF9chgb55zwwwM+PfJ7PRzXio8duCJnwORJ7z00PXzy++I3v88NwDzbkbgAn3f2Uu9cBfBvA/VvYnxBiB9lKsM8AOHPFz2e724QQ1yBbembfDGZ2DMAxAOjr69vpwwkhCFu5s88COHjFzwe6296Cux9396PufjSX589WQoidZSvB/iSAm83sBjMrAPg4gEe2xy0hxHZz1V/j3b1pZp8G8NfoSG8PufsLsTHr6+t44cXwryxfvEjHjZMFUNvDV0YnWkPUZqUpaltrc1Wg3AqvkLsV6JjKOl9RrVT5CnmjxaWmixHNsZgL+9hs8v1lyWowEH/0qqyvUVuzHX7ftr6HjslEVLlGRE0o5fh5UCYr2outJh3T389X4y3Dv50aUWsAABE5r7IeVlCajfB2AMjmwp9LY71Kx2zpmd3dHwXw6Fb2IYToDfoLOiESQcEuRCIo2IVIBAW7EImgYBciEXb8L+iuJAOglCOyUeSP664nEtuhaZ4QMjU5Tm2lmLQSyWqq1sIJI+sNLgt5ZH+FUiSBJpII421+vJHxcAJQs8H3V8hzPyLJiMgW+IdWq4fnqtHk89Ef2V9ugPtYjIxrWlgezESy6JqRDLVYpuXgAE++Kq9VqK3RDEtssYTD1ZXLwe3taPaoECIJFOxCJIKCXYhEULALkQgKdiESoaer8WaOooUTEIaGuCu3zIwFt+8p8cyJfJuXWiov8uSUVptf/6qVsO8ZngeD4UiZq1xkFXn58iofF/nUxofCK8KrKzxppR5JaKmSJA0gXldtkJR2atR5okamxd9YPpKQ0yKluAAgR5bPazU+ppDnH2imzRNoauUlagNJogKAPnIaN9tcMbi8FlZkWpF6grqzC5EICnYhEkHBLkQiKNiFSAQFuxCJoGAXIhF6Kr3lzDDWFz5kKSKtjJAkiMlhXvOrRdoPAYj0MQGyuUghNFJHrNaOSD8RnSwXScZo1bhE5Vl+jb5wIdxlptXg73q1wpM0Ki0uUw6WIt1daqT9E/h7zhiXjbJ9kU4sa1xm7c+HfcxFWiutR+oGVhtcemtHmnYtl7mPy5Xw+VMmUi8ArDfC50A9UmtQd3YhEkHBLkQiKNiFSAQFuxCJoGAXIhEU7EIkwpakNzM7DWAVHTWr6e5HowfLGiZHwxLKUJ5LXsVi2JbJcqmjFKnv1mhyGaodyeTqtKH//6lH6sW16lyWa3skoywieXmO
Z2Wt1sMZbK0Wn99KpNVUM2JbXeP+zy6G/chn+P6Gy3zuG+d5e7DqZS4dXjdxU3D71NQBOsaGwvXdAKC2dInaymWePXh5lUtvFy+HZdbTZ7gfrWw4dGt1Ltdth87+QXfnn4QQ4ppAX+OFSIStBrsD+Bsze8rMjm2HQ0KInWGrX+Pf7+6zZjYF4Mdm9rK7P37lL3QvAscAoBh5LhdC7CxburO7+2z3/wsAfgjg7sDvHHf3o+5+tJDTU4MQu8VVR5+ZDZjZ0JuvAfwOgOe3yzEhxPayla/x0wB+2G2XlAPwv9z9/8QG5HNZ7J8MFyIcLnDJYLA/LDVZRLpCJAPJItlmtSqXcTJEltszxNtQDQzwbK2Vy1zEGBnmGWWrkSKQb8yG91mu8UeoAp8OzPRHsvbyPDPv9KVw9l3NI0VCI1lvI8ND1Hbv7VzxXZkLy6xeiRxrgmdT1ip8Psplfu/sy/N9Htwbfm9TU9N0zPxKWMq79Mp5Ouaqg93dTwG482rHCyF6ix6ihUgEBbsQiaBgFyIRFOxCJIKCXYhE6G3ByaxhfCicjZarh6UaAOjLh93s7wv3NQOAWpXLU41Iv67R0XBfOQBwUqSw3uLXzEYjUgxxkPeBO7cQ7uUFAK+9wbOhFlbD7y1SuxDXR3rmfeRfH6G2A/u4/9976lRw+z+e5NJQs80z/XIZLpWtLi9QW6UcnsehIS6FocWz74pFPq5AsjMBoN/4uGYr/OFcd3A/HTO0GO4F+OzrfC50ZxciERTsQiSCgl2IRFCwC5EICnYhEqG3q/G5HKbG9wRt1UW+ap2xsJtl0jYHAKqxWlwWqccWaZPErozVBl9FHh3jCS31Fl9hPnX2HLUtrnAfWX26bKRl1HCR728qF171BYDiIlcMbh7eG9w+N879mF++QG21Cp/jp195hdoypB1SYyDSumqEJ6Agw0NmZISrQ0PtSLspUqfQ6yt0zCGSUNaX5/OrO7sQiaBgFyIRFOxCJIKCXYhEULALkQgKdiESocfSWx5jE5NB29ggb9eUyYSTCJZXluiYxlqZ768Va//EC7I5ScgZHOR15hrgtpdOcclorcZbCRWLfdxWCPtYGuCy0FiWy5RPnZyntmadnz61kbD0NjnG58PA5bBGk0uzlTqvhbdGas3Vm/w9W0RKjXQHQz4TaR2WidTey4XnsVnj0qYT2ZbkagHQnV2IZFCwC5EICnYhEkHBLkQiKNiFSAQFuxCJsKH0ZmYPAfg9ABfc/de728YBfAfAIQCnAXzM3bkO9i97A4iMZpH2OIy+SD2wfoSzggAgF7nGZTKRenJElusr8fZPF8/zrLHKRT5lN45ziarGVSgUicR26+EZOiYT2WEzy+d4JSJ95rLhOnlDBf657Bk7TG2Hb76O2l7/xZPU9vIrs8HthVxE1nIu2zabPGQyJOMQAPIFPo/tdvi8akd0PrPweRpRBjd1Z/9zAPe9bduDAB5z95sBPNb9WQhxDbNhsHf7rS++bfP9AB7uvn4YwEe22S8hxDZztc/s0+4+1319Hp2OrkKIa5gtL9B5p5g6/SM9MztmZifM7MRqJfKwKYTYUa422OfNbB8AdP+n9YTc/bi7H3X3o0P9fNFJCLGzXG2wPwLgge7rBwD8aHvcEULsFJuR3r4F4AMAJszsLIDPAfgCgO+a2acAvAHgY5s5WNsd1fVwcT1r8MwlIJyhtLbGC/LVG/w61szwbxjlCpfKVoht5iCfRm/y/V0/wYWSw/u5VFNZ5+NmbrkzuL3g/BFq6TIv3FkaDRcIBQBc4plcB/fuC25fXuPZfDf+2s3UNjzGs/aGx26jtqWF8PwvXeYttPIReTDjPOOw0Y5kU/JkSrQa4fM7kkRHW5FFkt42DnZ3/wQxfWijsUKIawf9BZ0QiaBgFyIRFOxCJIKCXYhEULALkQg9LTjpcLQsLE94ixcAZDJDqciLVA4Ocanm3AKX+V4/u0BtuXzYj8I878u2Ps/3d/MUl9c+9AEuQ702
+/ZUhX9haCZc0HNiT7gAJABcWOBFJUdHIzJUm/tfIAUWLyyEs9AAIFdcpraF5Tlqm53jWWr5fPg8GB3mWli1ygUsz/H7o0W0snZElstYeJxFMjAjbQL5cd75ECHELyMKdiESQcEuRCIo2IVIBAW7EImgYBciEXoqvWWzGYyODgZtzRyX3srlcMaWN7iccXmVZzW98QsuNZXLXMYpFcPXxrnXefbddJEXIZyZuZ7aRvffQG351UgKFSnCeeDOu/mQ81wOKzW5dNgCz6RbWwvb9vWHpUEAqLf4+7KB8HkDAAcG9lPb0GhYcly9dJ6OuTB/idoaxuXG9TovYokM18oG+sJZmPVqRFIkBSyNyHiA7uxCJIOCXYhEULALkQgKdiESQcEuRCL0dDW+3WpidTm80pmr81ptedLqBrwEGnJZbqyU+Ur92BBP/BgdCK+aVpf4avzUfl7DbeaOf0Ntz5+tU9srJ7nt3n3jwe3Ly3zM9OFw3ToAyKBCbfUaX6kf9fDK+soFvtJdqvNaePvGw+8LAJZbvC5c/o6x4PZqJLHmHx59hNrOnuHvORtp8RRrzMTybhqxNmWN8FyxpDFAd3YhkkHBLkQiKNiFSAQFuxCJoGAXIhEU7EIkwmbaPz0E4PcAXHD3X+9u+zyAPwDwpg7xWXd/dDMHzBIFohX5o38nskWGtIUCgJZx6W2JKzxYWYnUH6uF5at9I1yu+40PfpDaDtx6D7X94M8eora9kaSQbD1cX2/21Gt8fzfeTm3FPTdR24BzubSyGO71WWqHpTAAqFe5zHdxldtGJ3nS0J69h4Lbq+VhOibDTWgVePJPrAZdo8GlT2uGE7rMeaJXsxkO3a1Kb38O4L7A9q+4+5Huv00FuhBi99gw2N39cQC8nKkQ4peCrTyzf9rMnjWzh8yMfzcTQlwTXG2wfw3AYQBHAMwB+BL7RTM7ZmYnzOxEucKfW4QQO8tVBbu7z7t7y93bAL4OgJZBcffj7n7U3Y8O9vOqLUKIneWqgt3M9l3x40cBPL897gghdorNSG/fAvABABNmdhbA5wB8wMyOAHAApwH84WYOZgCMKAMtksUD8DY4kU488Gpkf5ESbuN7eNuovf1hqe+uo7fQMbfdy+W1pQtcbuxr8sy8Gw8coLY2eXN7p3jtt+Y6lzArkWy5epOPa1TDp1YLXDZ8bfYstT33/Alqu/ce7uOeveGsw5XVsDQIAKRjFABg4hCXWduxdk31iIxGJN3LC7wdVm017GSbZBsCmwh2d/9EYPM3NhonhLi20F/QCZEICnYhEkHBLkQiKNiFSAQFuxCJ0NOCk+5Am2T4VGtcMiiQLK9cjhf4y2a4HHPTXv7XvcUSv/4duv5gcPud7+eZbftuvYPanvnHP6O26w5yH/e+693UVpg8HNye6x+hYyrrXAKsrvDMtvlzZ6htaT4so7UaPHutNBQu6AkAExP8sz5z7mlqm943E9zerESyLKu8jZOtLVFby8MZhwDgTHMGUOoLv7fCXv6eV/pIJmgkonVnFyIRFOxCJIKCXYhEULALkQgKdiESQcEuRCL0VHozM+Sz4UMuRQoKttbDMkOpv0THZDNc6piKZLadmeOZRofvCpXiAw68O7y9A5fQGqtr1DYyxKWyyVuOUNtaLtwT7YWnn6RjalXux8oKn4+Ls7+gtmwrLH0Wi/yUm7khLJMBwB238MKXzSzPRMtnR8PbCzwrMrfOi0pW3pilNiYrA0Azclstk76E/Xv4+5omPQTz+Uh/OO6CEOJXCQW7EImgYBciERTsQiSCgl2IROhtIky7jVo1vNLZ38ddsWJ4tTKf4TXQvMVtpUHeGur3/93vU9u9v/uh4PbhiWk6Zv7US9SWjfi/vMpr0C2c/mdqO7caXhH+u7/8SzpmsMQTLtZrPGFk7zRXDIaHwivJr5/lyTP1yHyM7z9Ebbe8+73UhlZfcPPiMq93VyHqDwAsVbmP5vwcXq/yRK8yadnkZa4K3BYWGdDmIpTu
7EKkgoJdiERQsAuRCAp2IRJBwS5EIijYhUiEzbR/OgjgLwBMo9Pu6bi7f9XMxgF8B8AhdFpAfczdeYEuAA5H20ltuDZPIrBmWLZoeqTFU6TmV7FvmNqOvJfLOH35sET14jO8BtrSudeorVbj0srq0iK1nTn5IrWVPZwclG/xYw3muBQ5XOTJGJNjXHqbmz8f3N6MtPmqrHKZ78zrPOkGeIFayuVwDb1ijp8fzb4parvU5OdOqcRr6PUP8aStUi4sD65WVuiYZjssAUaUt03d2ZsA/tTdbwdwD4A/NrPbATwI4DF3vxnAY92fhRDXKBsGu7vPufvPuq9XAbwEYAbA/QAe7v7awwA+slNOCiG2zjt6ZjezQwDeA+AJANPuPtc1nUfna74Q4hpl08FuZoMAvg/gM+7+locJd3eQxwUzO2ZmJ8zsxFqV13IXQuwsmwp2M8ujE+jfdPcfdDfPm9m+rn0fgGDDa3c/7u5H3f3oQKmwHT4LIa6CDYPdzAydfuwvufuXrzA9AuCB7usHAPxo+90TQmwXm8l6ex+ATwJ4zsye6W77LIAvAPiumX0KwBsAPrbxrhxAWEZrN/lX/Fw+XDOuFan5VQfPTpoe4XXh/vqRv6K28emwxDO1L9wWCgDqFZ69ls+HJRcAGBzgEk8uw6WyASIP7p0K1ywDgOoqV0xLWe7jpYWL1Naohz+boSKXoOplLr29+vQJapt7+RVqqzVJS6Y8n8NWbH4PcCkSA/wczvRx6bNIZLQx8Lm67V03BLeXiqfomA2D3d3/HgDL+QvnfAohrjn0F3RCJIKCXYhEULALkQgKdiESQcEuRCL0tOAk3NBuhxf2C5HMq2KOFOvL8MKAHmkJ1K7zzKuLF8PZWgBQXgjbSg2endQGf1/jY1wOG90/SW3NVo3aZs+FffRIPlQmw0+DepNLmFnjhSoHimG5lCQwdvYXM0ayGFt1Lm9myPm2UuFyY72PyHUAhvbzuV8r8VZZq20uy62vhe+5e4ZvpGMmiJSay/PPUnd2IRJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJEJvpTcYMhbOoir28QwfJxlsA6WwvAMAA0MT1FZp8AykPUM85z5H/Khfnqdj2hm+v0qeS03T0+GsJgBo17mMc+sdB4Lbf/qTx+iYuleoLW9c3qyW+bjhoXDWXiHHT7msRfqhrfPP7PU5LqMtL4c/s5qt0TGTt/B74MxoJGvP+We9dJHPVWE9LGEOzEQyFSvhrMJ2RL3UnV2IRFCwC5EICnYhEkHBLkQiKNiFSISersZnDCjkwteXSo0nGGRJC6J2pD5apcGTGbJ5nlTRV+Crrfl82I9CP2+DNDLME3LOL/BV/MpMeFUdAKYO3kRtsxfCdeHe9Rvvo2PKC+eo7dQrvLXSWpknfuSy4fkfGeG19YzUJwSAuVnu4y/eiCTC9IXnf3iaKzmT4xEfI6qALfLPemyJh9rM1Hhw+4FRfg6cfDGc8FSr8iQv3dmFSAQFuxCJoGAXIhEU7EIkgoJdiERQsAuRCBtKb2Z2EMBfoNOS2QEcd/evmtnnAfwBgIXur37W3R+NHixnmJ4MX18aly7RcdVWWJJZ47kM8AxvDZWLJGMMD/PkgwJprVRd4zXoSpGaYKhz24mf/pTabryVS3Znz4YlmUykXl9/H68ll43Im6USl5rWymHprVrlkmgz0gJssMT9uPc9t1BbkSTkNLO8tl6rwZNWqme49JZZLVLbVP8Qtb3nlneFx4zyLuhPzb0e3N5s8Pe1GZ29CeBP3f1nZjYE4Ckz+3HX9hV3/6+b2IcQYpfZTK+3OQBz3derZvYSgJmddkwIsb28o2d2MzsE4D0Anuhu+rSZPWtmD5kZb40qhNh1Nh3sZjYI4PsAPuPuKwC+BuAwgCPo3Pm/RMYdM7MTZnZipcKfyYQQO8umgt3M8ugE+jfd/QcA4O7z7t5y9zaArwO4OzTW3Y+7+1F3Pzrczyt5CCF2lg2D3cwMwDcAvOTu
X75i+74rfu2jAJ7ffveEENvFZlbj3wfgkwCeM7Nnuts+C+ATZnYEHTnuNIA/3GhHhYLhuoPhu/uIcdni5JmwFDK/wLPX6i0u1QwO8re9VuEZVK12Obg9G7lmLi5wSXG1zGWS9Qb3I+vcNjQYXjqZP79Ix5xd43JS27lkNz3JZUprh7OvlpZ5vbi+Af6ZjY5w6aqQ5fNfqxMJNsflxrUa31+9HGl51ebjbjq4l9r27w3P45mzXGK9tBCOiWakhdZmVuP/HkDoE49q6kKIawv9BZ0QiaBgFyIRFOxCJIKCXYhEULALkQg9LTiZzRmGx0jmGJESAGBsKhs2DPCigRfneQHL9Uj7pFyBFxtkw9oNnmHXaHE/Lle5DDUQyfJar3CprLoeLjhZj/jYitjcydwDKK9E2j8Nhwt3Dg/z4pzVKt/fxUt8rgYHefadZcL3M2ty2baQ40VH+7hCjEKBz9Whmw5RW7US9uXxx1+kY5595UJ4X+tcztWdXYhEULALkQgKdiESQcEuRCIo2IVIBAW7EInQU+nNzJArhg9ZHOa57uOD4WtSrsplrXyJZ/+sRPpuocWvf6XiVHhInh+rVeP90Ar93I98js9HNsslx5qHfak3uNzokcw24woVvM4lwBYx5SPZZihwuXF5iUtv1TrvbzYyGpZSc0SSA4BMZO4r4NLW/MVValuKZDiuroWzGP/2717mxyIq5Xpd0psQyaNgFyIRFOxCJIKCXYhEULALkQgKdiESoafSW7ttKLOCfdlBOm5wIKzj5EtcFxqIpCeNjHCprLzCe5GVV8IFAMuVSNbbOrcNFXjBxiLpKwcAzRqXHHO58PW7ELms5/t4tpYZH9gfKdyZIaZmi0tDhVKkB98olxsXF7nktUqkyOFxPveVSM+5V0/zAqIvP3eG2qbHeTbl9AHy3jL8PJ0gBTjnV7kMqTu7EImgYBciERTsQiSCgl2IRFCwC5EIG67Gm1kRwOMA+rq//z13/5yZ3QDg2wD2AHgKwCfdPdqmtV4Hzr4RttWW+er50GR4BbdYiiRA8MV9jI/zt11e43XQlpfDtqVLPHFiiS/eItvmq+Bt50pDq8VX+NEO22JXdcvwRJhsjs9VNZI05GTRPU/aQgFAs8JbVLUi9elakeSa5XJ4HOsKBQCLEUXm9En+gS5fWqO2+ho/4N6RcGuo266foWOYi6+eX6FjNnNnrwH4LXe/E532zPeZ2T0AvgjgK+5+E4AlAJ/axL6EELvEhsHuHd7saJjv/nMAvwXge93tDwP4yI54KITYFjbbnz3b7eB6AcCPAbwGYNn9/31ZOwuAf+cQQuw6mwp2d2+5+xEABwDcDeDXNnsAMztmZifM7MTlMi92IITYWd7Rary7LwP4CYB/BWDUzN5cvTkAYJaMOe7uR9396MhgpMK+EGJH2TDYzWzSzEa7r0sAfhvAS+gE/b/t/toDAH60U04KIbbOZhJh9gF42Myy6Fwcvuvuf2VmLwL4tpn9ZwBPA/jGRjtyy6GVnwjaGoWjdFytHU78yDTDrY4AoDjC5aTRSf4NYyzDEzXGK+HEhOVF3i5o+SKX16prfPpbTS7nwfk1ut0M+7he5Y9QhUKk3l2O+7+6zhM1quSRLR9RZ4cy4eQOAGhnuKTUaPB57BsIS5jFPK93N1rgPt6IUWp79528DdWtd9xJbYduuim4/e57uNx49lw5uP0fXuMxsWGwu/uzAN4T2H4Kned3IcQvAfoLOiESQcEuRCIo2IVIBAW7EImgYBciEcwj2VXbfjCzBQBv5r1NAOA6Qe+QH29FfryVXzY/rnf3yZChp8H+lgObnXB3Lq7LD/khP7bVD32NFyIRFOxCJMJuBvvxXTz2lciPtyI/3sqvjB+79swuhOgt+hovRCLsSrCb2X1m9s9mdtLMHtwNH7p+nDaz58zsGTM70cPjPmRmF8zs+Su2jZvZj83s1e7/Y7vkx+fNbLY7J8+Y2Yd74MdBM/uJmb1oZi+Y2Z90t/d0
TiJ+9HROzKxoZv9kZj/v+vGfuttvMLMnunHzHTOLpEYGcPee/gOQRaes1Y0ACgB+DuD2XvvR9eU0gIldOO5vArgLwPNXbPsvAB7svn4QwBd3yY/PA/j3PZ6PfQDu6r4eAvAKgNt7PScRP3o6JwAMwGD3dR7AEwDuAfBdAB/vbv/vAP7onex3N+7sdwM46e6nvFN6+tsA7t8FP3YNd38cwNvrJt+PTuFOoEcFPIkfPcfd59z9Z93Xq+gUR5lBj+ck4kdP8Q7bXuR1N4J9BsCV7S53s1ilA/gbM3vKzI7tkg9vMu3uc93X5wFM76IvnzazZ7tf83f8ceJKzOwQOvUTnsAuzsnb/AB6PCc7UeQ19QW697v7XQB+F8Afm9lv7rZDQOfKjs6FaDf4GoDD6PQImAPwpV4d2MwGAXwfwGfc/S2laXo5JwE/ej4nvoUir4zdCPZZAAev+JkWq9xp3H22+/8FAD/E7lbemTezfQDQ/f/Cbjjh7vPdE60N4Ovo0ZyYWR6dAPumu/+gu7nncxLyY7fmpHvsd1zklbEbwf4kgJu7K4sFAB8H8EivnTCzATMbevM1gN8B8Hx81I7yCDqFO4FdLOD5ZnB1+Sh6MCdmZujUMHzJ3b98hamnc8L86PWc7FiR116tML5ttfHD6Kx0vgbgP+ySDzeiowT8HMALvfQDwLfQ+TrYQOfZ61Po9Mx7DMCrAP4WwPgu+fE/ATwH4Fl0gm1fD/x4Pzpf0Z8F8Ez334d7PScRP3o6JwDuQKeI67PoXFj+4xXn7D8BOAngfwPoeyf71V/QCZEIqS/QCZEMCnYhEkHBLkQiKNiFSAQFuxCJoGAXIhEU7EIkgoJdiET4vyrWWZ/xQ9u6AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# A quick check: look at one example\n",
    "x, y = next(iter(train_ds))\n",
    "print(idx2label.lookup(y))\n",
    "plt.imshow(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UkXcPZMzja3n"
   },
   "source": [
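    "We apply the image augmentation functions defined above with Dataset.map. During training we use the validation set only to evaluate the model, so its output must be deterministic; it is therefore transformed with transform_test. For the final prediction, we train the model on the union of the training and validation sets to make full use of all the labeled data."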
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MPNgNOBFjDpN"
   },
   "outputs": [],
   "source": [
    "train_iter = train_ds.map(transform_train)\n",
    "valid_iter = valid_ds.map(transform_test)\n",
    "train_valid_iter = train_valid_ds.map(transform_train)\n",
    "test_iter = test_ds.map(transform_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pmybflAINl7N"
   },
   "outputs": [],
   "source": [
    "# batch_size was set above: 1 for the tiny demo data, 128 for the full dataset\n",
    "train_iter = train_iter.shuffle(1024).batch(batch_size)\n",
    "valid_iter = valid_iter.batch(batch_size)\n",
    "train_valid_iter = train_valid_iter.shuffle(1024).batch(batch_size)\n",
    "# Do not shuffle the test set, so predictions stay in the same order as the files\n",
    "test_iter = test_iter.batch(batch_size)"
   ]
  },
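  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dataset.batch keeps a final batch that is smaller than batch_size unless drop_remainder=True is passed, so the number of batches per epoch is the number of examples divided by the batch size, rounded up. The small helper below (num_batches is hypothetical) makes the numbers explicit for the full dataset, whose split holds 45,000 training and 5,000 validation images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def num_batches(num_examples, batch_size, drop_remainder=False):\n",
    "    # Dataset.batch keeps a final partial batch unless drop_remainder=True\n",
    "    if drop_remainder:\n",
    "        return num_examples // batch_size\n",
    "    return math.ceil(num_examples / batch_size)\n",
    "\n",
    "\n",
    "print(num_batches(45000, 128))  # 352 training batches per epoch\n",
    "print(num_batches(45000, 128, drop_remainder=True))  # 351\n",
    "print(num_batches(5000, 128))  # 40 validation batches"
   ]
  },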
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 300
    },
    "colab_type": "code",
    "id": "gY04DJw7jzL8",
    "outputId": "8907b1bd-c7fe-4e0c-dd54-b5bd11824b77"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'ship', shape=(), dtype=string)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f5e9ee3a588>"
      ]
     },
     "execution_count": 21,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAbOklEQVR4nO2dW4xkV3WG/1WnLn2fu8eDPbEHYpQ4TrBRyyKCEAICOQjJIEUWfkB+sBgUYSlI5MFypOBIeYAogHiIiIZgYSKCcQKWrchKcCwkixdDmxjb2ICNL3iGuXlmeqavdTsrD1UTta39r+6u7q4evP9PGk31Wb3PWWefs+pU77/WWubuEEK8+alstwNCiOGgYBciExTsQmSCgl2ITFCwC5EJCnYhMqG6kcFmdhOArwAoAPyLu38++v3G6KiPT02xnfHjDGABBpMUB5Iif+vVS34C0QxHGL2ewWQFpuD2QCW6d4gtuszdsqQ2D2yDUqGzHN1Y6TFz83NYXl5OGgcOdjMrAPwTgA8COArgx2b2kLs/y8aMT03hA7femrRVC+5KtSiYE3RMFLRlt8ttHW6j+wzfIKJwid7GgnMLbgLmY8X5TWrOz9mCc6tU+AdDes3Q4fsLfKwZP9ZYo0FtRVFLbm+1uR8LS0vU1lxuUlt0X7HZAIBGNX3vR3PP3sQeePBBOmYjH+NvBPCCu7/o7i0A9wG4eQP7E0JsIRsJ9isAvLri56P9bUKIS5AtX6Azs8NmNmNmM83g45EQYmvZSLAfA3Bwxc9X9re9Dnc/4u7T7j7dGB3dwOGEEBthI8H+YwDXmNkhM6sD+DiAhzbHLSHEZjPwary7d8zsDgD/jd5i4z3u/rNoTMcdZ5rp1cxGg688NshbEpd3AA9WdjudNrU1W9wWrbYyIh+jFfdQoSrXvxpv4PNRDLgaXwSr8TWywgwPVuOD61IN/Gh2x7gf9fRKfbPD97fc5n60Otz/bjDOSj6uwZSG4Do3iNpRBtLghnR2d38YwMMb2YcQYjjoG3RCZIKCXYhMULALkQkKdiEyQcEuRCZsaDV+/RisqCctlSiZoZYeY0XwXtWNspN4WkKn5Pu0yvoznqKMrEqQ3BEm8gQ2lpVlYWZbJMtREwqa7AJUKix5iUtQ3uTfsGw1W9yRIHmJ5MGgE/jeaYxQm7MdAkCQXIPOMjUxOc9Lfl5Vclt5lB3ITUKINxMKdiEyQcEuRCYo2IXIBAW7EJkw3NV4d1gzvcLowbJvu0tsUa2fki9Ldpv8Pa5sBTa2aB0lrQTJIt2C+1gGOy2DBAla6iqskxesxgeKQTVQLkpP2yyYj6i8V7vFfSS3FACgYmSfbDuAMlBynN2LACyYD3SDuWqn99lpBUk3y4vp7UGylp7sQmSCgl2ITFCwC5EJCnYhMkHBLkQmKNiFyIThSm/dEuWFuaRpMUgiaFXSbno1nSADABbocpG00m7zhAs2LkpasUqQmRAk8kTSW9SAhjcSCovaUVP0NCgCa0E8iRJyqpHk1QxkykBGq5Tp6xn54V1+L1oge1ai+6DD7ys6LpD55hYuJLeXQfKMnuxCZIKCXYhMULALkQkKdiEyQcEuRCYo2IXIhA1Jb2b2MoA5AF0AHXefDgeUHWDhfNpWD2QoIl9ZNagXV/CadhZJPBakUNHWSoEcE7VIYnXasEprqAB2vKAUHrpBS6Oo/VMlUPMqxP9IpqxG+6vxunBhhiDNLIwy24JacsG9E82HFTzU6kyCDerWzZ0/mdwezi+1rJ0/c/fXNmE/QogtRB/jhciEjQa7A/i+mT1hZoc3wyEhxNaw0Y/x73H3Y2Z2GYBHzOzn7v7Yyl/ovwkcBoDG6OgGDyeEGJQNPdnd/Vj//1MAHgBwY+J3jrj7tLtP1+v8u+xCiK1l4GA3s3Ezm7z4GsCHADyzWY4JITaXjXyM3w/gAetpOlUA/+bu/xUNMDgaRTtpm9o9
ScdVduxJbi/GgjEFl2raQSuh5cV5amNVG6MWT1GLpEh6i/ZpUUspss9KkH1XBtIbr7LJ5TWA13OM2hNZZIyS9gIZjZ1ZGcxhN2gn5eFcDZhZyJS3xXRmGwAszx5Pbo8Keg4c7O7+IoB3DDpeCDFcJL0JkQkKdiEyQcEuRCYo2IXIBAW7EJkw1IKTlWqB0d1TSdvBt15Fx43t3JvcbkEmlAcFJ5uLPJuoucS/5Re2SyNEMpkHchhplbaqHyxbLsrMi+S1qLplJL0V5Lyj7K9u4GPUwyzK9HLiYze4PzptfqxuKy0dA/GTM5IHu0SWGx2v8THn9ye3V6t8jJ7sQmSCgl2ITFCwC5EJCnYhMkHBLkQmDHU13qs1tPfuS9qW9uyi49r19ApjpRqsZgcrzK0gYaFd8DRcts9opbgbtOPpGvcxStQoA//peQeJGJEt6KyEWrCy3qil57EaFJprdfhKdwd8HqN6fWbp51kZrOCXUTG52vpbXgHxNWstpc+7KPl81EfSqlGUCKMnuxCZoGAXIhMU7EJkgoJdiExQsAuRCQp2ITJhuNJbrYbO5QeStrOTY3RcvZGWcWo1/qV/C6SOTnDW86cXqK27tJTcvjw/R8e0Wjzpphr4EbU7apVB/bRmM32sNq+7FyXJFIEUObWDy6W7JnfSozEW5xepLWpf1QgmsiBSVFHl51UJJMXo6ciOBQDNJr8PTp9P15o7czxdZw4AinPnktujVl56sguRCQp2ITJBwS5EJijYhcgEBbsQmaBgFyITVpXezOweAB8BcMrdr+tv2w3gOwCuBvAygFvcPa0FrNxXUUWxO11PrlUE2UTVdL2wosbriKHLM7mWAxnk3K+PUltlIS2xnT/FJZL2Mm8nNTXJ5bXKrvQ8AcCcc9modW42uX3kQno7AERJXvWgxdbUwUPUtmPvW5LbW20uDc2e4RJmJA9and/GtVojuX1kjN9v7tzHMsgQrJPsTADwpbQkCgB+Jn1tFk6e5MdaTkvEZZBluZYn+zcA3PSGbXcCeNTdrwHwaP9nIcQlzKrB3u+3fvYNm28GcG//9b0APrrJfgkhNplB/2bf7+4XP7ueQK+jqxDiEmbDC3TeK9pN/6Ays8NmNmNmM63ga6VCiK1l0GA/aWYHAKD//yn2i+5+xN2n3X26PsEXe4QQW8ugwf4QgNv6r28D8ODmuCOE2CrWIr19G8D7AOw1s6MAPgfg8wDuN7PbAbwC4Ja1Hc5RJcUZffY8HVVppCWqCsmGA4D2Mpc6zv3qZWqbf+klaptspbOy6mdO0zEFkUgAYGR5nNq6YzwLsNvg79HLi+kMqu7RV+mYIpApu5M8s225yltlNXensxsvLPDMtvmT9AMiKkH7pHKES5g2SgozLvIswOUmt3W7XJYbD67ZUnDeF15NS7fNs4GabSSOgmzPVYPd3W8lpg+sNlYIcemgb9AJkQkKdiEyQcEuRCYo2IXIBAW7EJkw1IKTncVFnHzyyaRtZCSdnQQAS0FxQMbyYro4JADMnjhBbVNBv7FxT9smgv5fSzzBDuUilwfbQYVFm+ISTzmSnqsLC1wCrAQZWbUlLjVV61w6nJjcndz+8+dfoGMagZR66NBV1Lb/8suobbROes5VeIZa2Qky7KgFqAWFL189z789+tqJ3yS3nz3Ns95G6+l7TgUnhRAKdiFyQcEuRCYo2IXIBAW7EJmgYBciE4YrvS03ce4Xv0zapiam6Lg26XvmQRHCTlDY0IIMpNExnkHVJtlhi0GW1LkmlwDR4NLVFJGMAGDHKJfeMJre56mgf1mzzYtidsD9XwyKkRSWnquxoCjjwat+h9r+4Lprqe3y/fuorUb6r1Wci2hBTUlY9HwMCncuzAWFW0gmXWuZz31Rpq9nr5ZMGj3ZhcgEBbsQmaBgFyITFOxCZIKCXYhMGOpqfGGGHZX0KvNUwVfBJ0gdNNYWCogTArrjPNmlYfz978xcukbahRG+Om47ucowvj+dLAIARZBU
Mf8Sryd3/qVfJ7e3m/ycjaxYAwiXpssgaWhqLF377U//5N10zI49e6ht737eDmtqnM9/laxOt4MWYK02b6FUrXE1oRIkbNUbwThL+2hB3T1Qm1bjhcgeBbsQmaBgFyITFOxCZIKCXYhMULALkQlraf90D4CPADjl7tf1t90N4JMALvY9usvdH15tX2MjY/jDa29I2hojvJXQ2FRavioqgfRWctmi0+LSSjNIXFl8LT1dHfAabj4WtKgqufxz4ddHqW3uVd4maem12bQfgRRZBHJjlFgB5/M4Uk9fm+uu/X2+u4Jfz/lFnqxz5kK65RUA+GI66and4nO/HMyVR/LaGG9cOr/AE2FG6un7aiyoy8hLFPIEn7U82b8B4KbE9i+7+/X9f6sGuhBie1k12N39MQBnh+CLEGIL2cjf7HeY2VNmdo+Z8VafQohLgkGD/asA3gbgegDHAXyR/aKZHTazGTObWVriRSOEEFvLQMHu7ifdvevuJYCvAbgx+N0j7j7t7tOjUYUVIcSWMlCwm9mBFT9+DMAzm+OOEGKrWIv09m0A7wOw18yOAvgcgPeZ2fXopdi8DOBTazlYfWQUh95+HTkOf9+pMIktkIXKwNYuuWQ03+ZtksY6aYln7uzp5HYAWDp9ntvO83XP5hm+z/I8l5pqpJ4ZnUMAXVJbD4ilt06b1947fiwtHZ46cZyOmdjFl36OvZrO5gOA117hWYA+m57/ouTyWjOQbds1LoeN7uGZec0lLuk2iPQ2HtRDbJK6h0HXsNWD3d1vTWz++mrjhBCXFvoGnRCZoGAXIhMU7EJkgoJdiExQsAuRCUMtOAmrwOoTSVMlSK4yIv9E7Z94QT6gJK2JAGCpxSWSZdI2qjvL5bqJkmshUyWXcWaNX5oFItUAgBek3VFQu3A5kCKDLknodnnByZdefDG53Sr8nPdcvp/aTp86QW0nX3qJ2ooL6WtTJRIlALSCIpvj+y+ntrdffoDadpACnABwnLT66nb5dYkyDhl6sguRCQp2ITJBwS5EJijYhcgEBbsQmaBgFyIThiu9weDk/cWDQnkgUlkkvUWiXKQnlU0uydQ8PV17JniPsp3jaakRADpLPHut0+USYMe55EULS7aC3nc1nhHHpDwAqATjWuR4L/zil3TMudlz1Fat82u22OTFKBfOp4tzWlBUMlBLcfluXlRyx45xapskRVMBoFpLz3GnywuZsphQrzchhIJdiFxQsAuRCQp2ITJBwS5EJgx5NR5g7Wk8KJ7FV+qDZdNgVdKC97i68fY+O8ZIjbSgaO7EBF+hLcud1LYEvlrcCU576UK6rl2n5PNRK3its0q1xsfV+bg6GTc/yxWIfft4DbqdO7ht9iy/jU+WaVXDgkSpSsH3txjUKDx3/jVqYyvuANAmyouDqy5F4CNDT3YhMkHBLkQmKNiFyAQFuxCZoGAXIhMU7EJkwlraPx0E8E0A+9HTs464+1fMbDeA7wC4Gr0WULe4O89k6BMmqHAv1r23MpDlAhUK0ftfvZGuIxa1VmqMcumtMcoTaIo6r9VmFX7Zjqe7LmFxnrehqlb4OTcaPJGnFshyrYX08aoFlxRHR4Pb0YPElS5vQ1UhxQ2D/B5Uq/zeaS5z6e2VF5+ntkUiiQLA4kI6bDw4Z1SIRBxI2Gt5sncAfNbdrwXwLgCfNrNrAdwJ4FF3vwbAo/2fhRCXKKsGu7sfd/ef9F/PAXgOwBUAbgZwb//X7gXw0a1yUgixcdb1N7uZXQ3gBgCPA9jv7hdbcp5A72O+EOISZc3BbmYTAL4L4DPu/rrvPHqvr2/yjyMzO2xmM2Y2szDH/24UQmwtawp2M6uhF+jfcvfv9TefNLMDffsBAMmSIO5+xN2n3X16fHLHZvgshBiAVYPdzAy9fuzPufuXVpgeAnBb//VtAB7cfPeEEJvFWlJn3g3gEwCeNrMn+9vuAvB5APeb2e0AXgFwy1oOyLLbPOozxCS2KFMu2J8bf4+r1ngm1+hE
erqiNj1W8Cy6SpWny03u5OOuPMSlvupIetyJ46/QMa0ml66mJvZRWz2Q7E7MnUlu74LXVTt58jfUNjLHr8vSIq/XV62mr1l041erfH4RtIY6fTpd7w4AFudnqa3TXk5uL4I2X2W5/q/IrBrs7v5DcKH7A+s+ohBiW9A36ITIBAW7EJmgYBciExTsQmSCgl2ITBh6wUlaPDKU3jbxOABswCy1OpEA2y0uXUUZamUw/RZklI3v4HLYgWr6/bs2ks7YA4DlxbT0AwCT4zwzD20uo83PpmW0c0RmAoBWl8taFrTlqrAMMACj4+kvclUtKEga3IpFjWcjVqrcjzY/NVTq6WtTHeE+djrk/ghkZT3ZhcgEBbsQmaBgFyITFOxCZIKCXYhMULALkQlDld4cgA9QcZLLdXxn0WEi6a0SFHpkx4vktUgChHE/Iv8rVZ4BNjaZlsr2B9l87WUuHVaDPnDNxTlq23nZweT2bjAfjeC8iiDDDoGPHVKM0sAzFRH0gSuq/P4YGeP1GkYaXJYrltNFPTs1ns1XsbQfxfMv8zHUIoR4U6FgFyITFOxCZIKCXYhMULALkQlDT4ShbZmC5edBWkbFq/6D1aejewtrlg3W8CqG+2iWXpkeIbXpAGC0xn3sOl+19uDK7Nl/KLm9UuMJOejwZBfr8mM1unxcq9tObi89Sl6iJpTB3BeNSWrbtY+3VRgjyUFjgUrSKNKJMM/OzNAxerILkQkKdiEyQcEuRCYo2IXIBAW7EJmgYBciE1aV3szsIIBvoteS2QEccfevmNndAD4J4HT/V+9y94dX219JFZRADhuSXAdEKRCbjwVehs2wotZWnr6kQcm10I9KwWWtxlg6gaO3z7Q0VBY8kaS9NE9ttUhuLLj0ybokleDnVTq3LQVymBuXN8d2vYXbyNWeDCRFb6aTZAoiyQFr09k7AD7r7j8xs0kAT5jZI33bl939H9ewDyHENrOWXm/HARzvv54zs+cAXLHVjgkhNpd1/c1uZlcDuAHA4/1Nd5jZU2Z2j5nt2mTfhBCbyJqD3cwmAHwXwGfc/QKArwJ4G4Dr0Xvyf5GMO2xmM2Y2szB3YRNcFkIMwpqC3cxq6AX6t9z9ewDg7ifdvevuJYCvAbgxNdbdj7j7tLtPj09ObZbfQoh1smqwm5kB+DqA59z9Syu2H1jxax8D8MzmuyeE2CzWshr/bgCfAPC0mT3Z33YXgFvN7Hr0VK6XAXxqtR05DKWn31+i5LDBZLRIyguONVCW2qAeBi2qwqMNdm50TKDLRTIfjN8+RT19nSd3XEbHdEZ51lg1qBtYBFmHLIuxDGaqDDL9xtqBHEbubQCoNiKZMj2uEmQBtlpEJA56V61lNf6HSN97q2rqQohLB32DTohMULALkQkKdiEyQcEuRCYo2IXIhOEWnHQHPC0ZxJLR+nPRwv0RH1YfmbYNJtcBiGStKLMt8JG5YlGrrGA+wmNRC1CppP1vNHjWW63GM7bC9k/RNJJ5LMPMQX5m1SJKH+T7rAQ2dgJW2dxipXqyC5EJCnYhMkHBLkQmKNiFyAQFuxCZoGAXIhOG3uvNSEZRlMnFRZ5BpYkoxW4AWW79at3qDCjn8QS2SEILzjlMR1y/jxZIUEXBb8doXCh9UilyMNmziJIAuSn0sSQXLboubiwzLygeSi1CiDcVCnYhMkHBLkQmKNiFyAQFuxCZoGAXIhOGLr0xBSJSQui+QrEjGjhAXzmAOj9o1lvcf23QopLrlwejQomD6oqDzEicFzZYFuCm9+6LTiyQB8tIeqPbg/0NMMN6sguRCQp2ITJBwS5EJijYhcgEBbsQmbDqaryZjQB4DECj//v/4e6fM7NDAO4DsAfAEwA+4e6taF8OoE3a8XTBW/iw9cqBa78FRPvc7OPFWsL668ytNo6PGPC8BlFQQucHbYcV2IjkESkhIQMKQCFleqdeRglKLF6COnhrcKUJ4P3u/g702jPfZGbvAvAFAF92998FcA7A7WvY
lxBim1g12L3HfP/HWv+fA3g/gP/ob78XwEe3xEMhxKaw1v7sRb+D6ykAjwD4FYBZd7/YZvIogCu2xkUhxGawpmB39667Xw/gSgA3Avi9tR7AzA6b2YyZzSzNnx/QTSHERlnXary7zwL4AYA/BrDT7P8bdF8J4BgZc8Tdp919enRix4acFUIMzqrBbmb7zGxn//UogA8CeA69oP+L/q/dBuDBrXJSCLFx1pIIcwDAvWZWoPfmcL+7/6eZPQvgPjP7ewD/C+DrazlgSSSP6Iv9NkjbpUhCG1A/2WylL2q7tPkH46Yo4WLgnQ40IrgHBm6HNUCdvMAWypRRYlPgf4XYKkSS640hoRscZ9Vgd/enANyQ2P4ien+/CyF+C9A36ITIBAW7EJmgYBciExTsQmSCgl2ITLCtyByjBzM7DeCV/o97Abw2tINz5MfrkR+v57fNj6vcfV/KMNRgf92BzWbcfXpbDi4/5EeGfuhjvBCZoGAXIhO2M9iPbOOxVyI/Xo/8eD1vGj+27W92IcRw0cd4ITJhW4LdzG4ys1+Y2Qtmdud2+ND342Uze9rMnjSzmSEe9x4zO2Vmz6zYttvMHjGz5/v/79omP+42s2P9OXnSzD48BD8OmtkPzOxZM/uZmf1Vf/tQ5yTwY6hzYmYjZvYjM/tp34+/628/ZGaP9+PmO2ZWX9eO3X2o/wAU6JW1eiuAOoCfArh22H70fXkZwN5tOO57AbwTwDMrtv0DgDv7r+8E8IVt8uNuAH895Pk4AOCd/deTAH4J4Nphz0ngx1DnBL0s24n+6xqAxwG8C8D9AD7e3/7PAP5yPfvdjif7jQBecPcXvVd6+j4AN2+DH9uGuz8G4OwbNt+MXuFOYEgFPIkfQ8fdj7v7T/qv59ArjnIFhjwngR9DxXtsepHX7Qj2KwC8uuLn7SxW6QC+b2ZPmNnhbfLhIvvd/Xj/9QkA+7fRlzvM7Kn+x/wt/3NiJWZ2NXr1Ex7HNs7JG/wAhjwnW1HkNfcFuve4+zsB/DmAT5vZe7fbIaD3zo7Buh5vBl8F8Db0egQcB/DFYR3YzCYAfBfAZ9z9wkrbMOck4cfQ58Q3UOSVsR3BfgzAwRU/02KVW427H+v/fwrAA9jeyjsnzewAAPT/P7UdTrj7yf6NVgL4GoY0J2ZWQy/AvuXu3+tvHvqcpPzYrjnpH3vdRV4Z2xHsPwZwTX9lsQ7g4wAeGrYTZjZuZpMXXwP4EIBn4lFbykPoFe4EtrGA58Xg6vMxDGFOzMzQq2H4nLt/aYVpqHPC/Bj2nGxZkddhrTC+YbXxw+itdP4KwN9skw9vRU8J+CmAnw3TDwDfRu/jYBu9v71uR69n3qMAngfwPwB2b5Mf/wrgaQBPoRdsB4bgx3vQ+4j+FIAn+/8+POw5CfwY6pwA+CP0irg+hd4by9+uuGd/BOAFAP8OoLGe/eobdEJkQu4LdEJkg4JdiExQsAuRCQp2ITJBwS5EJijYhcgEBbsQmaBgFyIT/g8/K4ZZbS+8eQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# A small example: look at one training image and its label\n",
    "x, y = next(iter(train_iter))\n",
    "print(idx2label.lookup(y[0]))\n",
    "plt.imshow(x[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vBHRWzuAkudg"
   },
   "source": [
    "## 9.12.4. Define the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4toBYEgMkvcA"
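   },
   "source": [
    "Unlike the implementation in the “Residual Networks (ResNet)” section, the original MXNet version of this section builds the residual block on Gluon's `HybridBlock` class to improve execution efficiency.\n",
    "\n",
    "TensorFlow has no `HybridBlock` counterpart, so we simply reuse the `Residual` block defined earlier, written as a subclass of `keras.Model`. When the model is trained with `fit`, Keras already runs each training step as a compiled graph via `tf.function`, which serves a similar purpose."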
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "B_TAHxuGklAv"
   },
   "outputs": [],
   "source": [
    "class Residual(keras.Model):\n",
    "    # Residual block: two 3x3 convolutions with batch normalization plus a shortcut connection\n",
    "    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.conv1 = keras.layers.Conv2D(num_channels,\n",
    "                        padding='same',\n",
    "                        kernel_size=3,\n",
    "                        strides=strides)\n",
    "        self.conv2 = keras.layers.Conv2D(num_channels,\n",
    "                        kernel_size=3,\n",
    "                        padding='same')\n",
    "        if use_1x1conv:\n",
    "            # 1x1 convolution that matches the shortcut's channels and resolution to the main path\n",
    "            self.conv3 = keras.layers.Conv2D(num_channels,\n",
    "                            kernel_size=1,\n",
    "                            strides=strides)\n",
    "        else:\n",
    "            self.conv3 = None\n",
    "        self.bn1 = keras.layers.BatchNormalization()\n",
    "        self.bn2 = keras.layers.BatchNormalization()\n",
    "\n",
    "    def call(self, x, training=None):\n",
    "        # Forward training so batch normalization uses batch statistics only during training\n",
    "        y = keras.activations.relu(self.bn1(self.conv1(x), training=training))\n",
    "        y = self.bn2(self.conv2(y), training=training)\n",
    "        if self.conv3 is not None:\n",
    "            x = self.conv3(x)\n",
    "        return keras.activations.relu(y + x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tD4Hz1deEhlN"
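   },
   "source": [
    "Next we define the ResNet-18 model. Because CIFAR-10 images are only 32×32 pixels, the network starts with a single 3×3 convolution of stride 1 (followed by batch normalization and ReLU) and has no max-pooling layer, so the early feature maps keep their full resolution."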
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tmEGFSgiD0GF"
   },
   "outputs": [],
   "source": [
    "def resnet18(num_classes):\n",
    "    net = keras.Sequential()\n",
    "    # Stem for 32x32 inputs: one 3x3 convolution with stride 1 and no max-pooling\n",
    "    net.add(keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same'))\n",
    "    net.add(keras.layers.BatchNormalization())\n",
    "    net.add(keras.layers.Activation('relu'))\n",
    "\n",
    "    def resnet_block(num_channels, num_residuals, first_block=False):\n",
    "        # Except in the first stage, the first residual unit halves height and width\n",
    "        # and uses a 1x1 convolution on the shortcut to change the channel count\n",
    "        blk = keras.Sequential()\n",
    "        for i in range(num_residuals):\n",
    "            if i == 0 and not first_block:\n",
    "                blk.add(Residual(num_channels, use_1x1conv=True, strides=2))\n",
    "            else:\n",
    "                blk.add(Residual(num_channels))\n",
    "        return blk\n",
    "\n",
    "    # Four stages of two residual units each (64, 128, 256, 512 channels)\n",
    "    net.add(resnet_block(64, 2, first_block=True))\n",
    "    net.add(resnet_block(128, 2))\n",
    "    net.add(resnet_block(256, 2))\n",
    "    net.add(resnet_block(512, 2))\n",
    "    # Global average pooling followed by a softmax classifier\n",
    "    net.add(keras.layers.GlobalAveragePooling2D())\n",
    "    net.add(keras.layers.Dense(num_classes, activation='softmax'))\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cotiCsRrH1DX"
   },
   "source": [
    "The CIFAR-10 image classification task has 10 classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NXGCVHZaHzkG"
   },
   "outputs": [],
   "source": [
    "def get_net():\n",
    "    num_classes = 10\n",
    "    net = resnet18(num_classes)\n",
    "    return net\n",
    "net = get_net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 170
    },
    "colab_type": "code",
    "id": "PW2qZFwc_eSG",
    "outputId": "f8ee312c-187c-402a-fc7b-e56a05ea1722"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv2d output_shape: (1, 32, 32, 64)\n",
      "batch_normalization output_shape: (1, 32, 32, 64)\n",
      "activation output_shape: (1, 32, 32, 64)\n",
      "sequential_1 output_shape: (1, 32, 32, 64)\n",
      "sequential_2 output_shape: (1, 16, 16, 128)\n",
      "sequential_3 output_shape: (1, 8, 8, 256)\n",
      "sequential_4 output_shape: (1, 4, 4, 512)\n",
      "global_average_pooling2d output_shape: (1, 512)\n",
      "dense output_shape: (1, 10)\n"
     ]
    }
   ],
   "source": [
    "# Feed a random 32x32 RGB image through the layers one by one and print each output shape\n",
    "x = tf.random.uniform(shape=(1, 32, 32, 3))\n",
    "for layer in net.layers:\n",
    "    x = layer(x)\n",
    "    print(layer.name, \"output_shape:\", x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "e7ki445fJdqx"
   },
   "source": [
    "## 9.12.5. Define the Training Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "g9IV6K1vJgMB"
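   },
   "source": [
    "We will select the model and tune hyperparameters based on its performance on the validation set. Instead of writing a `train` function by hand, we configure training with Keras: `net.compile` sets up SGD with momentum 0.9, the sparse categorical cross-entropy loss, and the accuracy metric, while a `LearningRateScheduler` callback keeps the learning rate at `lr` for the first 10 epochs and then decays it as `lr * exp(-lr_decay * (epoch - 10))`. The Keras progress bar reports the time taken by each epoch, which helps compare the time cost of different models."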
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3FLFzBc2JIpp"
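   },
   "outputs": [],
   "source": [
    "lr = 0.1         # initial learning rate\n",
    "lr_decay = 0.01  # exponential decay rate applied after epoch 10\n",
    "\n",
    "def scheduler(epoch):\n",
    "    # Keep lr constant for the first 10 epochs, then decay it exponentially\n",
    "    if epoch < 10:\n",
    "        return lr\n",
    "    # Return a plain Python float, the type LearningRateScheduler expects\n",
    "    return float(lr * np.exp(lr_decay * (10 - epoch)))\n",
    "\n",
    "callback = keras.callbacks.LearningRateScheduler(scheduler)\n",
    "\n",
    "# SGD with momentum; the callback overwrites the learning rate at the start of each epoch\n",
    "net.compile(optimizer=keras.optimizers.SGD(learning_rate=lr, momentum=0.9),\n",
    "        loss='sparse_categorical_crossentropy',\n",
    "        metrics=['accuracy'])"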
   ]
  },
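  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the schedule concrete, the plain-Python sketch below re-implements it with `math.exp` (a substitution made only so it runs without TensorFlow) and prints the learning rate at a few epochs for the same `lr = 0.1` and `lr_decay = 0.01`. The function name `scheduler_demo` is hypothetical, chosen so it does not overwrite `scheduler`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "lr = 0.1         # initial learning rate, as in the cell above\n",
    "lr_decay = 0.01  # exponential decay rate applied after epoch 10\n",
    "\n",
    "\n",
    "def scheduler_demo(epoch):\n",
    "    # Constant for the first 10 epochs, then lr * exp(-lr_decay * (epoch - 10))\n",
    "    if epoch < 10:\n",
    "        return lr\n",
    "    return lr * math.exp(lr_decay * (10 - epoch))\n",
    "\n",
    "\n",
    "for epoch in [0, 9, 10, 20, 50, 110]:\n",
    "    print(epoch, round(scheduler_demo(epoch), 5))\n",
    "```\n",
    "\n",
    "With `lr_decay = 0.01` the decay is gentle: the learning rate only reaches `lr / e` (about 0.037) at epoch 110, so a larger `lr_decay` is needed if it should drop noticeably within a few dozen epochs."
   ]
  },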
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "t4PdOAFAMkHu"
   },
   "source": [
    "## 9.12.6. Train the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "a_g3lTJIMll1"
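   },
   "source": [
    "Now we can train and validate the model. For simplicity, we train for only 1 epoch here. After a single epoch the validation accuracy in the output below is 0.1, no better than random guessing among 10 classes, so set `epochs` much higher to get a useful model."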
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "colab_type": "code",
    "id": "QJKjycn7JkYn",
    "outputId": "e2cd7ae5-2ef9-404f-ca4b-b96e6ab5b508"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/12 [==============================] - 9s 770ms/step - loss: 13.8743 - accuracy: 0.0556 - val_loss: 4652004352.0000 - val_accuracy: 0.1000 - lr: 0.1000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f5e98d83668>"
      ]
     },
     "execution_count": 27,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(train_iter,\n",
    "    epochs=1,\n",
    "    validation_data=valid_iter,\n",
    "    callbacks=[callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lFFNhqEa1a5t"
   },
   "source": [
    "## 9.12.7. Classify the Test Set and Submit Results on Kaggle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MSNbSTH71iuS"
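   },
   "source": [
    "Once we are satisfied with the model design and hyperparameters, we retrain the model on all of the training data (including the validation set) and then classify the test set."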
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "OF5zlPOINHrt",
    "outputId": "0fea80a9-e6ac-4e6d-aa4a-5e4abd7ec4a8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13/13 [==============================] - 9s 730ms/step - loss: 4.6589 - accuracy: 0.0700 - lr: 0.1000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f5e98b51ba8>"
      ]
     },
     "execution_count": 28,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
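   ],
   "source": [
    "# Retrain a fresh model on training + validation data with the chosen hyperparameters\n",
    "net = get_net()\n",
    "net.compile(optimizer=keras.optimizers.SGD(learning_rate=lr, momentum=0.9),\n",
    "        loss='sparse_categorical_crossentropy',\n",
    "        metrics=['accuracy'])\n",
    "\n",
    "net.fit(train_valid_iter,\n",
    "    epochs=1,\n",
    "    callbacks=[callback])"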
   ]
  },
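  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only retrains the model; it does not yet produce predictions for the test set. The sketch below shows that final step under stated assumptions: the predicted class indices are already available (for example from `np.argmax(net.predict(test_iter), axis=1)`, where `test_iter` stands for a dataset of the test images that this section does not define), together with the test-image id of each prediction in the same order. The helper `write_submission`, the file name `submission_demo.csv`, and the label list are illustrative; `label_names[i]` must be the class name for index `i` under the label mapping used for training.\n",
    "\n",
    "```python\n",
    "import csv\n",
    "\n",
    "\n",
    "def write_submission(ids, pred_indices, label_names, path='submission.csv'):\n",
    "    # Hypothetical helper: ids[i] is the test-image id whose predicted class index is pred_indices[i]\n",
    "    if len(ids) != len(pred_indices):\n",
    "        raise ValueError('ids and pred_indices must have the same length')\n",
    "    # Sort rows by id so the file is easy to inspect; each id keeps its own prediction\n",
    "    rows = sorted(zip(ids, pred_indices), key=lambda row: row[0])\n",
    "    with open(path, 'w', newline='') as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow(['id', 'label'])\n",
    "        for image_id, class_idx in rows:\n",
    "            writer.writerow([image_id, label_names[class_idx]])\n",
    "    return path\n",
    "\n",
    "\n",
    "# Demo with made-up predictions for three test images (illustrative label order)\n",
    "label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "               'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "out_path = write_submission([1, 10, 2], [0, 9, 3], label_names, path='submission_demo.csv')\n",
    "with open(out_path) as f:\n",
    "    print(f.read())\n",
    "```\n",
    "\n",
    "Pairing ids with predictions through `zip` before sorting keeps each id attached to its own prediction, as long as the two input lists are aligned."
   ]
  },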
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bFKhqrBw5aGi"
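   },
   "source": [
    "The test-set predictions must be saved as a CSV file, such as submission.csv, in the format required by the competition: an `id` column and a `label` column with one row per test image. Submitting this file works the same way as in the section “Kaggle Competition: House Price Prediction”."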
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "TekJFvqQHXTC",
    "vBHRWzuAkudg",
    "e7ki445fJdqx",
    "t4PdOAFAMkHu",
    "lFFNhqEa1a5t"
   ],
   "name": "9.12_kaggle_cifar10.ipynb",
   "provenance": []
  },
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
